This notebook is a modification of [this notebook on Neural Machine Translation by Jointly Learning to Align and Translate.](https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb)

# 3 - Neural Machine Translation with Attention

In this third notebook on sequence-to-sequence models using PyTorch and TorchText, we'll be implementing neural machine translation models with several variants of attention. We will ask you along the way to implement some of the key components of this model. Training all these models with the full dataset can take a long time. If you don't have access to the necessary compute, it's fine to just use a subset of the data, or train for fewer epochs.

## Introduction

As a reminder, here is the general encoder-decoder model:

![](seq2seq1.png)

In the vanilla Seq2Seq model, the architecture is set-up in a way to reduce "information compression" by explicitly passing the context vector, $z$, to the decoder at every time-step and by passing both the context vector and input word, $y_t$, along with the hidden state, $s_t$, to the linear layer, $f$, to make a prediction.

![](seq2seq7.png)

Even though we have reduced some of this compression, our context vector still needs to contain all of the information about the source sentence. The model implemented in this notebook avoids this compression by allowing the decoder to look at the entire source sentence (via its hidden states) at each decoding step! How does it do this? It uses *attention*. 

Attention works by first, calculating an attention vector, $a$, that is the length of the source sentence. The attention vector has the property that each element is between 0 and 1, and the entire vector sums to 1 ($\sum_{i}a_i = 1$). We then calculate a weighted sum of our source sentence hidden states, $H$, to get a weighted source vector, $w$. 

$$w = \sum_{i}a_ih_i$$

We calculate a new weighted source vector every time-step when decoding, using it as input to our decoder RNN as well as the linear layer to make a prediction. We'll explain how to do all of this during the tutorial.

# Python3 environment requirements
This notebook was tested with Python 3.11.4 and the following library versions:
* torch==2.1.0
* torchtext==0.16.0
* spacy==3.5.3
* numpy==1.24.3
* portalocker>=2.0.0

In [1]:
import os
import torch
import random
import getpass
import numpy as np

def seed(seed = 1810):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# set seed for reproducibility
SEED = 1810
seed(SEED)

## Installing the Requirements

In [2]:
!python3 -m spacy download en_core_web_sm
!python3 -m spacy download de_core_news_sm

Collecting en-core-web-sm==3.5.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m78.9 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
Collecting de-core-news-sm==3.5.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.5.0/de_core_news_sm-3.5.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m40.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('de_core_news_sm')


## Preparing Data

First we import all the required modules.

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from torchtext.datasets import Multi30k
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer

from tqdm import tqdm

import spacy

import random
import math
import time

We create the tokenizers using German and English spaCy models.

In [7]:
src_tokenizer = get_tokenizer('spacy','de_core_news_sm' )
tgt_tokenizer = get_tokenizer('spacy','en_core_web_sm' )

Load the data.

In [8]:
from torchtext.data.functional import to_map_style_dataset

train_iter, valid_iter = Multi30k(split=('train', 'valid'))
train_list, valid_list = to_map_style_dataset(train_iter), to_map_style_dataset(valid_iter) 

Next, we'll build the *vocabulary* for the source and target languages. The vocabulary is used to associate each unique token with an index (an integer). The vocabularies of the source and target languages are distinct.

Using the `min_freq` argument, we only allow tokens that appear at least 2 times to appear in our vocabulary. Tokens that appear only once are converted into an `<unk>` (unknown) token.

It is important to note that our vocabulary should only be built from the training set and not the validation/test set. This prevents "information leakage" into our model, giving us artifically inflated validation/test scores.

The `set_default_index` function would return the `<unk>` token indice for out of vocabulary (OOV) words. 

In [9]:
def transformed(fn, index, data):
    #create a generator for either source or target sentences
    for line in data:
        yield fn(line[index])

src_vocab = build_vocab_from_iterator(transformed(src_tokenizer, 0, train_iter), specials=('<pad>', '<unk>', '<sos>', '<eos>'), min_freq=2)
tgt_vocab = build_vocab_from_iterator(transformed(tgt_tokenizer, 1, train_iter), specials=('<pad>', '<unk>', '<sos>', '<eos>'), min_freq=2)
src_vocab.set_default_index(src_vocab['<unk>'])
tgt_vocab.set_default_index(tgt_vocab['<unk>'])



Define the device.

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Define transform pipeline for preparing the sequences for the model.

In [11]:
#transform each sentence into a sequence of token_ids  
def text_transform(sentence, vocab, tokenizer):
    return [vocab['<sos>']] + [vocab[token] for token in tokenizer(sentence)] + [vocab['<eos>']]

Create the iterators using pytorch `Dataloader`. We define custom `collate_fn` and `batch_sampler` for handling textual inputs. 

In [12]:
#prepare a batch for the model (i.e., transform the sequences into token_ids and pad them to have similar lengths)
def collate_batch(batch):
   src_list, tgt_list = [], []
   for (_src, _tgt) in batch:
        src_list.append(torch.tensor(text_transform(_src, src_vocab, src_tokenizer), device=device))
        tgt_list.append(torch.tensor(text_transform(_tgt, tgt_vocab, tgt_tokenizer), device=device))
   return pad_sequence(src_list), pad_sequence(tgt_list)


#group sequences with similar lenghts together
def batch_sampler(iterator):
    indices = [(i, len(src_tokenizer(s[0]))) for i, s in enumerate(iterator)]
    random.shuffle(indices)
    pooled_indices = []
    # create pool of indices with similar lengths 
    for i in range(0, len(indices), BATCH_SIZE * 100):
        pooled_indices.extend(sorted(indices[i:i + BATCH_SIZE * 100], key=lambda x: x[1]))

    pooled_indices = [x[0] for x in pooled_indices]

    # create indices for current batch
    batch_indices = []
    for i in range(0, len(pooled_indices), BATCH_SIZE):
        batch_indices.append(pooled_indices[i:i + BATCH_SIZE])
    return batch_indices


BATCH_SIZE = 8  # A batch size of 8

train_dataloader, valid_dataloader = [DataLoader(iterator, batch_sampler=batch_sampler(iterator),
                               collate_fn=collate_batch) for iterator in [train_list[:256], valid_list]]

## Building the Seq2Seq Model

### Encoder

First, we'll build the encoder. We use a *bidirectional RNN*. With a bidirectional RNN, we have two RNNs in each layer. A *forward RNN* going over the sentence from left to right (shown below in green), and a *backward RNN* going over the sentence from right to left (yellow). All we need to do in code is set `bidirectional = True` and then pass the embedded sentence to the RNN as before. 

![](seq2seq8.png)

We now have:

$$\begin{align*}
h_t^\rightarrow &= \text{EncoderGRU}^\rightarrow(x_t^\rightarrow,h_t^\rightarrow)\\
h_t^\leftarrow &= \text{EncoderGRU}^\leftarrow(x_t^\leftarrow,h_t^\leftarrow)
\end{align*}$$

Where $x_0^\rightarrow = \text{<sos>}, x_1^\rightarrow = \text{guten}$ and $x_0^\leftarrow = \text{<eos>}, x_1^\leftarrow = \text{morgen}$.

As before, we only pass an input (`embedded`) to the RNN, which tells PyTorch to initialize both the forward and backward initial hidden states ($h_0^\rightarrow$ and $h_0^\leftarrow$, respectively) to a tensor of all zeros. We'll also get two context vectors, one from the forward RNN after it has seen the final word in the sentence, $z^\rightarrow=h_T^\rightarrow$, and one from the backward RNN after it has seen the first word in the sentence, $z^\leftarrow=h_T^\leftarrow$.

The RNN returns `outputs` and `hidden`. 

`outputs` is of size **[src sent len, batch size, hid dim * num directions]** where the first `hid_dim` elements in the third axis are the hidden states from the top layer forward RNN, and the last `hid_dim` elements are hidden states from the top layer backward RNN. You can think of the third axis as being the forward and backward hidden states stacked on top of each other, i.e. $h_1 = [h_1^\rightarrow; h_{T}^\leftarrow]$, $h_2 = [h_2^\rightarrow; h_{T-1}^\leftarrow]$ and we can denote all stacked encoder hidden states as $H=\{ h_1, h_2, ..., h_T\}$.

`hidden` is of size **[n layers * num directions, batch size, hid dim]**, where **[-2, :, :]** gives the top layer forward RNN hidden state after the final time-step (i.e. after it has seen the last word in the sentence) and **[-1, :, :]** gives the top layer backward RNN hidden state after the final time-step (i.e. after it has seen the first word in the sentence).

As the decoder is not bidirectional, it only needs a single context vector, $z$, to use as its initial hidden state, $s_0$, and we currently have two, a forward and a backward one ($z^\rightarrow=h_T^\rightarrow$ and $z^\leftarrow=h_T^\leftarrow$, respectively). We solve this by concatenating the two context vectors together, passing them through a linear layer, $g$, and applying the $\tanh$ activation function. 

$$z=\tanh(g(h_T^\rightarrow, h_T^\leftarrow)) = \tanh(g(z^\rightarrow, z^\leftarrow)) = s_0$$

As we want our model to look back over the whole of the source sentence we return `outputs`, the stacked forward and backward hidden states for every token in the source sentence. We also return `hidden`, which acts as our initial hidden state in the decoder.

In [13]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        
        self.embedding = nn.Embedding(input_dim, emb_dim)
        
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
        
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src):
        
        #src = [src sent len, batch size]
        
        embedded = self.dropout(self.embedding(src))
        
        #embedded = [src sent len, batch size, emb dim]
        
        outputs, hidden = self.rnn(embedded)
                
        #outputs = [src sent len, batch size, hid dim * num directions]
        #hidden = [n layers * num directions, batch size, hid dim]
        
        #hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
        #outputs are always from the last layer
        
        #hidden [-2, :, : ] is the last of the forwards RNN 
        #hidden [-1, :, : ] is the last of the backwards RNN
        
        #initial decoder hidden is final hidden state of the forwards and backwards 
        #  encoder RNNs fed through a linear layer
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
        
        #outputs = [src sent len, batch size, enc hid dim * 2]
        #hidden = [batch size, dec hid dim]
        
        return outputs, hidden

### Part 1: Implement Attention

Next up is the attention layer. This will take in the previous hidden state of the decoder, $s_{t-1}$, and all of the stacked forward and backward hidden states from the encoder, $H$. The layer will output an attention vector, $a_t$, that is the length of the source sentence, each element is between 0 and 1 and the entire vector sums to 1.

Intuitively, this layer takes what we have decoded so far, $s_{t-1}$, and all of what we have encoded, $H$, to produce a vector, $a_t$, that represents which words in the source sentence we should pay the most attention to in order to correctly predict the next word to decode, $\hat{y}_{t+1}$.

We explore three different variants of computing $\hat{a_t}$, which are the logits for the softmax operation that gives us the attention weights $a_t = \text{softmax}(\hat{a_t})$.

Graphically, this looks something like below. This is for calculating the very first attention vector, where $s_{t-1} = s_0 = z$. The green/yellow blocks represent the hidden states from both the forward and backward RNNs, and the attention computation is all done within the pink block.

![](seq2seq9.png)

The first attention variant is *additive* attention, which was also used in the original paper [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473):

$$\hat{a_t}_i = v\tanh(W_a([s_{t-1}; h_i]))$$

where $v, W_a$ are a learnable vector and matrix, respectively.

The second variant is called *multiplicative* attention. It is sometimes favorable because it has fewer learnable parameters (i.e., it is more memory efficient) and is faster to compute. It looks as follows:

$$\hat{a_t}_i = h_i^T W_a s_{t-1}$$

Finally, we would like to understand how attention can be used if we don't have contextual information like the hidden state of the decoder $s_{t-1}$ in NMT. In this case, we only use the encoder hidden state $h_i$ itself to compute the attention weights, which is called *self-attention*:

$$\hat{a_t}_i = v\tanh(W_a h_i])$$

Self-attention is not really practically useful in sequence-to-sequence tasks, because here you usually have contextual information from the decoder. However, self-attention is the crucial building block of Transformer models, which you will learn about in the next lecture.

Note that many, many more variants of attention have been proposed. For a categorization attempt, you may refer to [this paper](https://arxiv.org/pdf/1711.07341.pdf).

Below, we would like you to implement the three variants of attention. Each attention function gets as input the decoder's hidden state and the encoder outputs, and is supposed to compute the logits $\hat{a_t}$.

Note that in this implementation you don't need to necessarily take proper care of not attending to elements in the input that are invalid. This is not a problem here, because due to the custom batch_sampler most sentences in a batch have the same or a similar length. In order to see how to properly take care of invalid inputs, please refer to [this notebook](https://github.com/bentrevett/pytorch-seq2seq/blob/master/4%20-%20Packed%20Padded%20Sequences%2C%20Masking%20and%20Inference.ipynb).

In [14]:
class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, hidden, encoder_outputs):
        
        #hidden = [batch size, dec hid dim] or [batch size, seq_len, dec hid dim]
        #encoder_outputs = [src sent len, batch size, enc hid dim * 2]
                
        attention = self._compute_logits(hidden, encoder_outputs)
        
        #attention= [batch size, src len]
        
        return F.softmax(attention, dim=1)
    
    def _compute_logits(self, hidden, encoder_ouputs):
        pass

Now we expect you to implement the `_compute_logits()` function for each of the Attention classes. 

$$\hat{a_t}_i = v\tanh(W_a([s_{t-1}; h_i]))$$

In [16]:
class AdditiveAttention(Attention):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        
        # define the learnable parameters that you need here
        self.Wa = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)
        
    def _compute_logits(self, hidden, encoder_outputs):
        
        #hidden = [batch size, dec hid dim] or [batch size, seq_len, dec hid dim]
        #encoder_outputs = [src sent len, batch size, enc hid dim * 2]
        
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        
        # TODO: Implement me        
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        
        energy = torch.tanh(self.Wa(torch.cat((hidden, encoder_outputs), dim = 2))) 
        
        logits = self.v(energy).squeeze(2)
        
        return logits

$$\hat{a_t}_i = h_i^T W_a s_{t-1}$$

In [17]:
class MultiplicativeAttention(Attention):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        
        # define the learnable parameters that you need here
        self.Wa = nn.Linear(dec_hid_dim, enc_hid_dim*2)
        
    def _compute_logits(self, hidden, encoder_outputs):
        
        #hidden = [batch size, dec hid dim] or [batch size, seq_len, dec hid dim]
        #encoder_outputs = [src sent len, batch size, enc hid dim * 2]
        
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        
        # TODO: Implement me
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)       
        
        logits = torch.sum(torch.multiply(encoder_outputs , self.Wa(hidden)), dim = 2)  
        
        return logits

$$\hat{a_t}_i = v\tanh(W_a h_i])$$

In [18]:
class SelfAttention(Attention):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        
        # define the learnable parameters that you need here.
        self.Wa = nn.Linear(2*enc_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)
        
    def _compute_logits(self, hidden, encoder_outputs):
        
        #hidden = [batch size, dec hid dim] or [batch size, seq_len, dec hid dim]
        #encoder_outputs = [src sent len, batch size, enc hid dim * 2]
        
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        
        # TODO: Implement me
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
 
        energy = torch.tanh(self.Wa(encoder_outputs))
        logits = self.v(energy).squeeze(2)
        
        return logits

### Part 2: Decoder

Next up is the decoder. 

The decoder contains the attention layer, `attention`, which takes the previous hidden state, $s_{t-1}$, all of the encoder hidden states, $H$, and returns the attention vector, $a_t$.

We then use this attention vector to create a weighted source vector, $w_t$, denoted by `weighted`, which is a weighted sum of the encoder hidden states, $H$, using $a_t$ as the weights.

$$w_t = a_t H$$

The input word (that has been embedded), $y_t$, the weighted source vector, $w_t$, and the previous decoder hidden state, $s_{t-1}$, are then all passed into the decoder RNN, with $y_t$ and $w_t$ being concatenated together.

$$s_t = \text{DecoderGRU}(y_t, w_t, s_{t-1})$$

We then pass $y_t$, $w_t$ and $s_t$ through the linear layer, $f$, to make a prediction of the next word in the target sentence, $\hat{y}_{t+1}$. This is done by concatenating them all together.

$$\hat{y}_{t+1} = f(y_t, w_t, s_t)$$

The image below shows decoding the first word in an example translation.

![](seq2seq10.png)

The green/yellow blocks show the forward/backward encoder RNNs which output $H$, the red block shows the context vector, $z = h_T = \tanh(g(h^\rightarrow_T,h^\leftarrow_T)) = \tanh(g(z^\rightarrow, z^\leftarrow)) = s_0$, the blue block shows the decoder RNN which outputs $s_t$, the purple block shows the linear layer, $f$, which outputs $\hat{y}_{t+1}$ and the orange block shows the calculation of the weighted sum over $H$ by $a_t$ and outputs $w_t$. Not shown is the calculation of $a_t$.

Below, we ask you to implement the described model given the components initialized in the init function.

In [19]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention
        
        self.embedding = nn.Embedding(output_dim, emb_dim)
        
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        
        self.out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input_tgt, hidden, encoder_outputs):
             
        #input_tgt = [batch size]
        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src sent len, batch size, enc hid dim * 2]
        
        input_tgt = input_tgt.unsqueeze(0)
        
        #input = [1, batch size]
        
        embedded = self.dropout(self.embedding(input_tgt))
        
        #embedded = [1, batch size, emb dim]
        
        # TODO: Compute the output and the new hidden state of the decoder
        
        a = self.attention(hidden, encoder_outputs)
        a = a.unsqueeze(1)
        
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        
        weighted = torch.bmm(a, encoder_outputs)
        weighted = weighted.permute(1, 0, 2)
        
        rnn_input = torch.cat((embedded, weighted), dim = 2)
        
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
                
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        hidden = hidden.squeeze(0)
        
        output = self.out(torch.cat((output, weighted, embedded), dim = 1))
        
        #output = [batch size, output dim]
        #hidden = [batch size, dec hid dim]
        
        return output, hidden

### Part 3: Seq2Seq

Below, we would like to implement the forward pass in a Seq2Seq model.
This involves the following first steps:
- the `outputs` tensor is created to hold all predictions, $\hat{Y}$
- the source sequence, $X$, is fed into the encoder to receive $z$ and $H$
- the initial decoder hidden state is set to be the `context` vector, $s_0 = z = h_T$
- we use a batch of `<sos>` tokens as the first `tgt_input`, $y_1$

We ask you to implement the decoding part of the forward pass in a Seq2Seq model.
- we decode within a loop:
  - inserting the input token $y_t$, previous hidden state, $s_{t-1}$, and all encoder outputs, $H$, into the decoder
  - receiving a prediction, $\hat{y}_{t+1}$, and a new hidden state, $s_t$
  - we then decide if we are going to teacher force or not, setting the next input as appropriate
  
You might not be familiar with the term **teacher forcing**, as it was not covered in the lecture. Teacher forcing means that the decoder receives the ground truth output from the previous step as input to the next step, as opposed to receiving its own previous prediction as input. This enables the model to learn faster, as the inputs to the decoder are less noisy. However, only providing ground truth inputs might cause a discrepancy between training- and test time, which can lead to poor performance. It is therefore often good to balance both practices during training, which is controlled by the teacher forcing ratio hyperparameter. At each step of decoding, you determine at random whether the model receives the ground truth input or its own previous output (called curriculum learning). More information on teacher forcing can be found in [this blog post](https://machinelearningmastery.com/teacher-forcing-for-recurrent-neural-networks/). 

In [20]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        
        #src = [src sent len, batch size]
        #trg = [trg sent len, batch size]
        #teacher_forcing_ratio is probability to use teacher forcing
        #e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time
        
        batch_size = src.shape[1]
        max_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        
        #tensor to store decoder outputs
        outputs = torch.zeros(max_len, batch_size, trg_vocab_size).to(self.device)
        
        #encoder_outputs is all hidden states of the input sequence, back and forwards
        #hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src)
                
        #first input to the decoder is the <sos> tokens
        input = trg[0,:]
        
        for t in range(1, max_len):
            # TODO: Implement the loop!
            output, hidden = self.decoder(input, hidden, encoder_outputs)
            
            outputs[t] = output
            
            teacher_force = random.random() < teacher_forcing_ratio
            
            top1 = output.argmax(1)
            
            input = trg[t] if teacher_force else top1
        
        return outputs

## Training the Seq2Seq Model

The rest of this notebook deals with initializing and training the model.

We initialise our parameters, encoder, decoder and seq2seq model (placing it on the GPU if we have one). 

In [21]:
def build_model(attention,
                input_dim,
                output_dim,
                enc_emb_dim,
                dec_emb_dim,
                enc_hid_dim,
                dec_hid_dim,
                enc_dropout, 
                dec_dropout):

    attn = attention(enc_hid_dim, dec_hid_dim)
    enc = Encoder(input_dim, enc_emb_dim, enc_hid_dim, dec_hid_dim, enc_dropout)
    dec = Decoder(output_dim, dec_emb_dim, enc_hid_dim, dec_hid_dim, dec_dropout, attn)

    model = Seq2Seq(enc, dec, device).to(device)
    
    def init_weights(m):
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.normal_(param.data, mean=0, std=0.01)
            else:
                nn.init.constant_(param.data, 0)
            
    model.apply(init_weights)
    return model

We use a simplified version of the weight initialization scheme used in the paper. Here, we will initialize all biases to zero and all weights from $\mathcal{N}(0, 0.01)$.

We then create the training loop...

In [22]:
def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        
        src = batch[0]
        trg = batch[1]
        
        optimizer.zero_grad()
        
        output = model(src, trg)
        
        #trg = [trg sent len, batch size]
        #output = [trg sent len, batch size, output dim]
        
        output = output[1:].view(-1, output.shape[-1])
        trg = trg[1:].view(-1)
        
        #trg = [(trg sent len - 1) * batch size]
        #output = [(trg sent len - 1) * batch size, output dim]
        
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator.dataset)

...and the evaluation loop, remembering to set the model to `eval` mode and turn off teaching forcing.

In [23]:
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):
            
            src = batch[0]
            trg = batch[1]
            
            output = model(src, trg, 0) #turn off teacher forcing

            #trg = [trg sent len, batch size]
            #output = [trg sent len, batch size, output dim]

            output = output[1:].view(-1, output.shape[-1])
            trg = trg[1:].view(-1)

            #trg = [(trg sent len - 1) * batch size]
            #output = [(trg sent len - 1) * batch size, output dim]

            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator.dataset)

Finally, define a timing function.

In [24]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

Then, we train our model, saving the parameters that give us the best validation loss.

In [25]:
def train_epochs(model, train_iterator, optimizer, criterion, n_epochs):
    best_valid_loss = float('inf')

    for epoch in range(n_epochs):

        start_time = time.time()

        CLIP = 1
        train_loss = train(model, train_dataloader, optimizer, criterion, CLIP)
        valid_loss = evaluate(model, valid_dataloader, criterion)

        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'tut3-model.pt')

        print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Lets see how the different variants of attentions compare against each other.

In [26]:
def train_model_with_attention(attention):
    INPUT_DIM = len(src_vocab)
    OUTPUT_DIM = len(tgt_vocab)
    ATTENTION = attention
    ENC_EMB_DIM = 32
    DEC_EMB_DIM = 32
    ENC_HID_DIM = 128
    DEC_HID_DIM = 128
    ENC_DROPOUT = 0.
    DEC_DROPOUT = 0.
    N_EPOCHS = 10
    PAD_IDX = tgt_vocab['<pad>']
    criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)
    model = build_model(ATTENTION,
                        INPUT_DIM,
                    OUTPUT_DIM,
                    ENC_EMB_DIM,
                    DEC_EMB_DIM,
                    ENC_HID_DIM,
                    DEC_HID_DIM,
                    ENC_DROPOUT, 
                    DEC_DROPOUT)
    optimizer = optim.Adam(model.parameters())

    train_epochs(model, train_dataloader, optimizer, criterion, N_EPOCHS)
    
print("Additive attention")
print("========")
train_model_with_attention(AdditiveAttention)
print("========")

print("Multiplicative attention")
print("========")
train_model_with_attention(MultiplicativeAttention)
print("========")

print("Self- attention")
print("========")
train_model_with_attention(SelfAttention)
print("========")

Additive attention
Epoch: 01 | Time: 1m 25s
	Train Loss: 0.921 | Train PPL:   2.513
	 Val. Loss: 0.728 |  Val. PPL:   2.071
Epoch: 02 | Time: 1m 31s
	Train Loss: 0.663 | Train PPL:   1.941
	 Val. Loss: 0.736 |  Val. PPL:   2.087
Epoch: 03 | Time: 1m 37s
	Train Loss: 0.639 | Train PPL:   1.895
	 Val. Loss: 0.739 |  Val. PPL:   2.095
Epoch: 04 | Time: 1m 26s
	Train Loss: 0.626 | Train PPL:   1.869
	 Val. Loss: 0.735 |  Val. PPL:   2.085
Epoch: 05 | Time: 1m 27s
	Train Loss: 0.612 | Train PPL:   1.844
	 Val. Loss: 0.732 |  Val. PPL:   2.080
Epoch: 06 | Time: 1m 27s
	Train Loss: 0.601 | Train PPL:   1.824
	 Val. Loss: 0.729 |  Val. PPL:   2.073
Epoch: 07 | Time: 1m 26s
	Train Loss: 0.590 | Train PPL:   1.804
	 Val. Loss: 0.729 |  Val. PPL:   2.074
Epoch: 08 | Time: 1m 26s
	Train Loss: 0.581 | Train PPL:   1.788
	 Val. Loss: 0.736 |  Val. PPL:   2.088
Epoch: 09 | Time: 1m 25s
	Train Loss: 0.575 | Train PPL:   1.777
	 Val. Loss: 0.741 |  Val. PPL:   2.098
Epoch: 10 | Time: 1m 25s
	Train Loss

# More applications for attention
The core idea behind the attention mechanism can be described as controlling the flow of information depending on some context. In this sense, it is similar to the gating mechanism in LSTMs, only that the sum of the gate activations must some to one. This idea has wide applications across ML and NLP. [This blogpost summarizes some of the important applications](https://medium.com/@joealato/attention-in-nlp-734c6fa9d983).

Besides Sequence2Sequence models and self-attention, there are several other applications:
* [Transformers](https://arxiv.org/abs/1706.03762)
* [Memory networks](https://arxiv.org/abs/1410.3916)
* [Differentiable Neural Computer](https://www.nature.com/articles/nature20101)
* [Attention between a Sentence Pair](https://arxiv.org/abs/1509.06664)
* [Hierarchical Attention for Text Classification](https://www.aclweb.org/anthology/N16-1174/)
* [Pointer Networks](https://arxiv.org/abs/1506.03134) and [Neural Combinatorial Optimization with RL](https://arxiv.org/pdf/1611.09940.pdf)
* [Integrating Knowledge Bases](https://arxiv.org/pdf/1902.09091.pdf)
* [Unsupervised Object Discovery](https://arxiv.org/pdf/2006.15055.pdf)
* [Perceiver: General Perception with Iterative Attention](https://arxiv.org/pdf/2103.03206.pdf)